# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

from collections.abc import AsyncGenerator, Mapping
from typing import Any, Final, cast

import torch
from fastapi import Request
from typing_extensions import override

from vllm.engine.protocol import EngineClient
from vllm.entrypoints.chat_utils import ChatTemplateContentFormatOption
from vllm.entrypoints.logger import RequestLogger
from vllm.entrypoints.openai.protocol import (
    EMBED_DTYPE_TO_TORCH_DTYPE,
    EmbeddingChatRequest,
    EmbeddingCompletionRequest,
    EmbeddingRequest,
    EmbeddingResponse,
    EmbeddingResponseData,
    ErrorResponse,
    UsageInfo,
)
from vllm.entrypoints.openai.serving_engine import (
    EmbeddingServeContext,
    OpenAIServing,
    ServeContext,
    TextTokensPrompt,
)
from vllm.entrypoints.openai.serving_models import OpenAIServingModels
from vllm.entrypoints.openai.utils import encoding_pooling_output
from vllm.entrypoints.renderer import RenderConfig
from vllm.inputs.data import TokensPrompt as EngineTokensPrompt
from vllm.logger import init_logger
from vllm.outputs import (
    EmbeddingRequestOutput,
    PoolingOutput,
    PoolingRequestOutput,
    RequestOutput,
)
from vllm.pooling_params import PoolingParams
from vllm.utils import chunk_list
from vllm.utils.async_utils import merge_async_iterators

logger = init_logger(__name__)


class EmbeddingMixin(OpenAIServing):
    """Shared embedding-serving logic, with optional chunked processing.

    When the pooler config enables chunked processing, a tokenized prompt
    longer than ``max_model_len`` is split into chunks. Each chunk is sent
    to the engine as its own request, and the per-chunk embeddings are
    combined in ``_collect_batch`` as a mean weighted by each chunk's token
    count.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        pooler_config = self.model_config.pooler_config

        # Cache pooler settings once to avoid repeated attribute lookups
        # on every request.
        self.supports_chunked_processing = bool(
            pooler_config and pooler_config.enable_chunked_processing
        )
        # None when there is no pooler config or max_embed_len is unset/falsy.
        self.max_embed_len = (
            pooler_config.max_embed_len
            if pooler_config and pooler_config.max_embed_len
            else None
        )

    @override
    async def _preprocess(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Validate the request and render its prompts into ``ctx``.

        Checks ``embed_dtype``, resolves the LoRA adapter, and fills
        ``ctx.engine_prompts``. Chat requests go through the chat template;
        completion requests go through the renderer.

        Returns:
            An ``ErrorResponse`` on validation or rendering failure,
            otherwise ``None``.
        """
        ctx = cast(EmbeddingServeContext, ctx)
        try:
            if ctx.request.embed_dtype not in EMBED_DTYPE_TO_TORCH_DTYPE:
                return self.create_error_response(
                    f"embed_dtype={ctx.request.embed_dtype!r} is not supported. "
                    f"Supported types: {EMBED_DTYPE_TO_TORCH_DTYPE.keys()}"
                )

            ctx.lora_request = self._maybe_get_adapters(ctx.request)

            tokenizer = await self.engine_client.get_tokenizer()
            renderer = self._get_renderer(tokenizer)

            if isinstance(ctx.request, EmbeddingChatRequest):
                (
                    _,
                    _,
                    ctx.engine_prompts,
                ) = await self._preprocess_chat(
                    ctx.request,
                    tokenizer,
                    ctx.request.messages,
                    chat_template=ctx.request.chat_template or ctx.chat_template,
                    chat_template_content_format=ctx.chat_template_content_format,
                    add_generation_prompt=ctx.request.add_generation_prompt,
                    continue_final_message=False,
                    add_special_tokens=ctx.request.add_special_tokens,
                )
            else:
                ctx.engine_prompts = await renderer.render_prompt(
                    prompt_or_prompts=ctx.request.input,
                    config=self._build_render_config(ctx.request),
                )
            return None
        except (ValueError, TypeError) as e:
            logger.exception("Error in preprocessing prompt inputs")
            return self.create_error_response(str(e))

    def _build_render_config(self, request: EmbeddingCompletionRequest) -> RenderConfig:
        """Build the renderer config for a completion-style embedding request.

        With chunked processing, no ``max_length`` is imposed here because
        long prompts are split later in ``_prepare_generators``. Otherwise
        the limit is ``max_embed_len``, falling back to ``max_model_len``.
        """
        if self._should_use_chunked_processing(request):
            max_length = None
        else:
            max_length = self.max_embed_len or self.max_model_len

        return RenderConfig(
            max_length=max_length,
            truncate_prompt_tokens=request.truncate_prompt_tokens,
            add_special_tokens=request.add_special_tokens,
        )

    @override
    def _build_response(
        self,
        ctx: ServeContext,
    ) -> EmbeddingResponse | ErrorResponse:
        """Convert ``ctx.final_res_batch`` into an ``EmbeddingResponse``.

        Each output is encoded according to the request's
        ``encoding_format`` and ``embed_dtype``. Usage counts prompt tokens
        only; embeddings produce no completion tokens.
        """
        items: list[EmbeddingResponseData] = []
        num_prompt_tokens = 0

        final_res_batch_checked = cast(list[PoolingRequestOutput], ctx.final_res_batch)

        for idx, final_res in enumerate(final_res_batch_checked):
            item = EmbeddingResponseData(
                index=idx,
                embedding=encoding_pooling_output(
                    final_res, ctx.request.encoding_format, ctx.request.embed_dtype
                ),
            )
            prompt_token_ids = final_res.prompt_token_ids

            items.append(item)
            num_prompt_tokens += len(prompt_token_ids)

        usage = UsageInfo(
            prompt_tokens=num_prompt_tokens,
            total_tokens=num_prompt_tokens,
        )

        return EmbeddingResponse(
            id=ctx.request_id,
            created=ctx.created_time,
            model=ctx.model_name,
            data=items,
            usage=usage,
        )

    def _get_max_position_embeddings(self) -> int:
        """Get the model's effective maximum sequence length for chunking."""
        return self.model_config.max_model_len

    def _should_use_chunked_processing(self, request) -> bool:
        """Check if chunked processing should be used for this request."""
        return (
            isinstance(request, (EmbeddingCompletionRequest, EmbeddingChatRequest))
            and self.supports_chunked_processing
        )

    async def _process_chunked_request(
        self,
        ctx: EmbeddingServeContext,
        original_prompt: TextTokensPrompt,
        pooling_params,
        trace_headers,
        prompt_idx: int,
    ) -> list[AsyncGenerator[PoolingRequestOutput, None]]:
        """Start one engine encode request per chunk of a long prompt.

        The prompt's token IDs are split into chunks of at most
        ``_get_max_position_embeddings()`` tokens. Each chunk gets the
        request ID ``f"{ctx.request_id}-prompt-{prompt_idx}-chunk-{chunk_idx}"``,
        which ``_resolve_prompt_index`` relies on when collecting results.

        Returns:
            The un-consumed engine generators, one per chunk, in chunk order.
        """
        generators: list[AsyncGenerator[PoolingRequestOutput, None]] = []
        token_ids = original_prompt["prompt_token_ids"]

        # Split into chunks using max_position_embeddings
        max_pos_embeddings = self._get_max_position_embeddings()
        # Process all chunks for MEAN aggregation
        for chunk_idx, chunk_tokens in enumerate(
            chunk_list(token_ids, max_pos_embeddings)
        ):
            # Create a request ID for this chunk
            chunk_request_id = f"{ctx.request_id}-prompt-{prompt_idx}-chunk-{chunk_idx}"

            # Create engine prompt for this chunk
            chunk_engine_prompt = EngineTokensPrompt(prompt_token_ids=chunk_tokens)

            # Chunk prompts are token-only; log them with empty text.
            chunk_text = ""
            chunk_request_prompt = TextTokensPrompt(
                prompt=chunk_text, prompt_token_ids=chunk_tokens
            )

            # Log the chunk
            self._log_inputs(
                chunk_request_id,
                chunk_request_prompt,
                params=pooling_params,
                lora_request=ctx.lora_request,
            )

            # Create the engine generator for this chunk (returned unwrapped).
            original_generator = self.engine_client.encode(
                chunk_engine_prompt,
                pooling_params,
                chunk_request_id,
                lora_request=ctx.lora_request,
                trace_headers=trace_headers,
                priority=getattr(ctx.request, "priority", 0),
            )

            generators.append(original_generator)

        return generators

    def _validate_input(
        self,
        request,
        input_ids: list[int],
        input_text: str,
    ) -> TextTokensPrompt:
        """Override to support chunked processing for embedding requests.

        For completion/chat embedding requests, the input is first checked
        against ``max_embed_len`` (or ``max_model_len`` when unset). Inputs
        longer than ``max_model_len`` are then allowed only when chunked
        processing is enabled.

        Raises:
            ValueError: If the input exceeds the applicable length limit.
        """
        token_num = len(input_ids)

        # Note: EmbeddingRequest doesn't have max_tokens
        if isinstance(request, (EmbeddingCompletionRequest, EmbeddingChatRequest)):
            # Check if chunked processing is enabled for pooling models
            enable_chunked = self._should_use_chunked_processing(request)

            # Use max_position_embeddings for chunked processing decisions
            max_pos_embeddings = self._get_max_position_embeddings()

            # Determine the effective max length for validation
            if self.max_embed_len is not None:
                # Use max_embed_len for validation instead of max_model_len
                length_type = "maximum embedding input length"
                max_length_value = self.max_embed_len
            else:
                # Fall back to max_model_len validation (original behavior)
                length_type = "maximum context length"
                max_length_value = self.max_model_len

            validation_error_msg = (
                "This model's {length_type} is {max_length_value} tokens. "
                "However, you requested {token_num} tokens in the input for "
                "embedding generation. Please reduce the length of the input."
            )

            chunked_processing_error_msg = (
                "This model's {length_type} is {max_length_value} tokens. "
                "However, you requested {token_num} tokens in the input for "
                "embedding generation. Please reduce the length of the input "
                "or enable chunked processing."
            )

            # Check if input exceeds max length
            if token_num > max_length_value:
                raise ValueError(
                    validation_error_msg.format(
                        length_type=length_type,
                        max_length_value=max_length_value,
                        token_num=token_num,
                    )
                )

            # Check for chunked processing
            # when exceeding max_position_embeddings
            if token_num > max_pos_embeddings:
                if enable_chunked:
                    # Allow long inputs when chunked processing is enabled
                    logger.info(
                        "Input length %s exceeds max_position_embeddings "
                        "%s, will use chunked processing",
                        token_num,
                        max_pos_embeddings,
                    )
                else:
                    raise ValueError(
                        chunked_processing_error_msg.format(
                            length_type="maximum position embeddings length",
                            max_length_value=max_pos_embeddings,
                            token_num=token_num,
                        )
                    )

            return TextTokensPrompt(prompt=input_text, prompt_token_ids=input_ids)

        # For other request types, use the parent's implementation
        return super()._validate_input(request, input_ids, input_text)

    def _is_text_tokens_prompt(self, prompt) -> bool:
        """Check if a prompt is a TextTokensPrompt (has prompt_token_ids)."""
        return (
            isinstance(prompt, dict)
            and "prompt_token_ids" in prompt
            and "prompt_embeds" not in prompt
        )

    async def _create_single_prompt_generator(
        self,
        ctx: EmbeddingServeContext,
        engine_prompt: EngineTokensPrompt,
        pooling_params: PoolingParams,
        trace_headers: Mapping[str, str] | None,
        prompt_index: int,
    ) -> AsyncGenerator[RequestOutput | PoolingRequestOutput, None]:
        """Create a generator for a single prompt using standard processing.

        The engine request ID is ``f"{ctx.request_id}-{prompt_index}"``,
        which ``_resolve_prompt_index`` relies on when collecting results.
        """
        request_id_item = f"{ctx.request_id}-{prompt_index}"

        self._log_inputs(
            request_id_item,
            engine_prompt,
            params=pooling_params,
            lora_request=ctx.lora_request,
        )

        # Return the original generator without wrapping
        return self.engine_client.encode(
            engine_prompt,
            pooling_params,
            request_id_item,
            lora_request=ctx.lora_request,
            trace_headers=trace_headers,
            priority=getattr(ctx.request, "priority", 0),
        )

    @override
    async def _prepare_generators(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Override to support chunked processing.

        Without chunked processing, this delegates to the parent class.
        Otherwise each engine prompt gets either one generator, or one
        generator per chunk when its token count exceeds
        ``_get_max_position_embeddings()``. All generators are merged into
        ``ctx.result_generator``.
        """
        ctx = cast(EmbeddingServeContext, ctx)

        # Check if we should use chunked processing
        use_chunked = self._should_use_chunked_processing(ctx.request)

        # If no chunked processing needed, delegate to parent class
        if not use_chunked:
            return await super()._prepare_generators(ctx)

        # Custom logic for chunked processing
        generators: list[
            AsyncGenerator[RequestOutput | PoolingRequestOutput, None]
        ] = []

        try:
            trace_headers = (
                None
                if ctx.raw_request is None
                else await self._get_trace_headers(ctx.raw_request.headers)
            )

            pooling_params = self._create_pooling_params(ctx)
            if isinstance(pooling_params, ErrorResponse):
                return pooling_params

            # Verify and set the task for pooling params
            try:
                pooling_params.verify("embed", self.model_config)
            except ValueError as e:
                return self.create_error_response(str(e))

            if ctx.engine_prompts is None:
                return self.create_error_response("Engine prompts not available")

            max_pos_embeddings = self._get_max_position_embeddings()

            for i, engine_prompt in enumerate(ctx.engine_prompts):
                # Check if this specific prompt needs chunked processing
                if self._is_text_tokens_prompt(engine_prompt):
                    # Cast to TextTokensPrompt since we've verified
                    # prompt_token_ids
                    text_tokens_prompt = cast(TextTokensPrompt, engine_prompt)
                    if len(text_tokens_prompt["prompt_token_ids"]) > max_pos_embeddings:
                        # Use chunked processing for this prompt
                        chunk_generators = await self._process_chunked_request(
                            ctx, text_tokens_prompt, pooling_params, trace_headers, i
                        )
                        generators.extend(chunk_generators)
                        continue

                # Normal processing for short prompts or non-token prompts
                generator = await self._create_single_prompt_generator(
                    ctx, engine_prompt, pooling_params, trace_headers, i
                )
                generators.append(generator)

            ctx.result_generator = merge_async_iterators(*generators)

            return None

        except Exception as e:
            # TODO: Use a vllm-specific Validation Error
            return self.create_error_response(str(e))

    @staticmethod
    def _resolve_prompt_index(
        base_request_id: str,
        sub_request_id: str,
        result_idx: int,
    ) -> tuple[int, bool]:
        """Map an engine sub-request ID back to ``(prompt_idx, is_chunk)``.

        Sub-request IDs are built as ``f"{base}-{i}"`` for regular prompts
        and ``f"{base}-prompt-{i}-chunk-{j}"`` for chunks. Only the part
        after the known ``base`` is parsed, so a client-supplied request ID
        that itself contains ``-prompt-`` or ``-chunk-`` is not mistaken
        for a chunk marker.

        If the ID does not start with ``base``, the previous whole-ID
        heuristic is used, falling back to ``result_idx`` (the index of the
        merged generator) when even that fails.
        """
        prefix = f"{base_request_id}-"
        if sub_request_id.startswith(prefix):
            suffix = sub_request_id[len(prefix):]
            if suffix.isdecimal():
                return int(suffix), False
            parts = suffix.split("-")
            if (
                len(parts) == 4
                and parts[0] == "prompt"
                and parts[2] == "chunk"
                and parts[1].isdecimal()
                and parts[3].isdecimal()
            ):
                return int(parts[1]), True

        # Fallback: heuristic parse of the whole ID (previous behavior).
        parts = sub_request_id.split("-")
        if "-chunk-" in sub_request_id:
            try:
                return int(parts[parts.index("prompt") + 1]), True
            except (ValueError, IndexError):
                return result_idx, True
        try:
            return int(parts[-1]), False
        except (ValueError, IndexError):
            return result_idx, False

    @override
    async def _collect_batch(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Collect and aggregate batch results
        with support for chunked processing.

        For chunked requests, performs online aggregation to
        minimize memory usage: a running token-weighted sum per prompt,
        divided by the total token count at the end.
        For regular requests, collects results normally.
        """
        ctx = cast(EmbeddingServeContext, ctx)
        try:
            if ctx.engine_prompts is None:
                return self.create_error_response("Engine prompts not available")

            # Check if we used chunked processing
            use_chunked = self._should_use_chunked_processing(ctx.request)

            if not use_chunked:
                return await super()._collect_batch(ctx=ctx)

            if ctx.result_generator is None:
                return self.create_error_response("Result generator not available")

            # Online aggregation for chunked requests to
            # minimize memory usage
            # Track aggregation state for each prompt
            prompt_aggregators: dict[int, dict[str, Any]] = {}
            short_prompts_results: dict[int, PoolingRequestOutput] = {}

            async for result_idx, result in ctx.result_generator:
                prompt_idx, is_chunk = self._resolve_prompt_index(
                    ctx.request_id, result.request_id, result_idx
                )

                if is_chunk:
                    # Initialize aggregator for this prompt if needed
                    if prompt_idx not in prompt_aggregators:
                        prompt_aggregators[prompt_idx] = {
                            "weighted_sum": None,
                            "total_weight": 0,
                            "chunk_count": 0,
                            # Strip only the trailing "-chunk-{j}" marker.
                            "request_id": result.request_id.rpartition("-chunk-")[0],
                        }

                    aggregator = prompt_aggregators[prompt_idx]

                    # MEAN pooling with online weighted averaging
                    # Ensure result is PoolingRequestOutput
                    # for embedding processing
                    if not isinstance(result, PoolingRequestOutput):
                        return self.create_error_response(
                            f"Expected PoolingRequestOutput for "
                            f"chunked embedding, got "
                            f"{type(result).__name__}"
                        )

                    # Handle both PoolingOutput and
                    # EmbeddingOutput types
                    if hasattr(result.outputs, "data"):
                        # PoolingOutput case
                        embedding_data = result.outputs.data
                    elif hasattr(result.outputs, "embedding"):
                        # EmbeddingOutput case; converted to a tensor
                        # below if it is not one already
                        embedding_data = result.outputs.embedding
                    else:
                        return self.create_error_response(
                            f"Unsupported output type: {type(result.outputs).__name__}"
                        )

                    if not isinstance(embedding_data, torch.Tensor):
                        embedding_data = torch.tensor(
                            embedding_data, dtype=torch.float32
                        )

                    if result.prompt_token_ids is None:
                        return self.create_error_response(
                            "prompt_token_ids cannot be None for chunked processing"
                        )
                    # Weight each chunk by its token count.
                    weight = len(result.prompt_token_ids)

                    weighted_embedding = embedding_data.to(dtype=torch.float32) * weight

                    if aggregator["weighted_sum"] is None:
                        # First chunk
                        aggregator["weighted_sum"] = weighted_embedding
                    else:
                        # Accumulate
                        aggregator["weighted_sum"] += weighted_embedding

                    aggregator["total_weight"] += weight
                    aggregator["chunk_count"] += 1
                else:
                    short_prompts_results[prompt_idx] = cast(
                        PoolingRequestOutput, result
                    )

            # Finalize aggregated results
            final_res_batch: list[PoolingRequestOutput | EmbeddingRequestOutput] = []
            num_prompts = len(ctx.engine_prompts)

            for prompt_idx in range(num_prompts):
                if prompt_idx in prompt_aggregators:
                    # Finalize MEAN aggregation for this chunked prompt
                    aggregator = prompt_aggregators[prompt_idx]

                    weighted_sum = aggregator["weighted_sum"]
                    total_weight = aggregator["total_weight"]

                    if (
                        weighted_sum is not None
                        and isinstance(weighted_sum, torch.Tensor)
                        and isinstance(total_weight, (int, float))
                        and total_weight > 0
                    ):
                        # Compute final mean embedding
                        final_embedding = weighted_sum / total_weight

                        # Create a PoolingRequestOutput
                        # for the aggregated result
                        pooling_output_data = PoolingOutput(data=final_embedding)

                        # Get original prompt token IDs for this prompt
                        original_prompt = ctx.engine_prompts[prompt_idx]
                        if not self._is_text_tokens_prompt(original_prompt):
                            return self.create_error_response(
                                f"Chunked prompt {prompt_idx} is not a TextTokensPrompt"
                            )

                        original_token_ids = cast(TextTokensPrompt, original_prompt)[
                            "prompt_token_ids"
                        ]

                        pooling_request_output = PoolingRequestOutput(
                            request_id=aggregator["request_id"],
                            prompt_token_ids=original_token_ids,
                            outputs=pooling_output_data,
                            finished=True,
                        )

                        final_res_batch.append(pooling_request_output)
                    else:
                        return self.create_error_response(
                            f"Failed to aggregate chunks for prompt {prompt_idx}"
                        )
                elif prompt_idx in short_prompts_results:
                    final_res_batch.append(
                        cast(PoolingRequestOutput, short_prompts_results[prompt_idx])
                    )
                else:
                    return self.create_error_response(
                        f"Result not found for prompt {prompt_idx}"
                    )

            ctx.final_res_batch = cast(
                list[RequestOutput | PoolingRequestOutput], final_res_batch
            )

            return None

        except Exception as e:
            return self.create_error_response(str(e))


class OpenAIServingEmbedding(EmbeddingMixin):
    """OpenAI-compatible ``/v1/embeddings`` handler.

    Adds chat-template configuration and per-request chat-template
    validation on top of ``EmbeddingMixin``.
    """

    request_id_prefix = "embd"

    def __init__(
        self,
        engine_client: EngineClient,
        models: OpenAIServingModels,
        *,
        request_logger: RequestLogger | None,
        chat_template: str | None,
        chat_template_content_format: ChatTemplateContentFormatOption,
        trust_request_chat_template: bool = False,
        log_error_stack: bool = False,
    ) -> None:
        super().__init__(
            engine_client=engine_client,
            models=models,
            request_logger=request_logger,
            log_error_stack=log_error_stack,
        )

        # Server-level chat template; a request-supplied template takes
        # precedence in EmbeddingMixin._preprocess.
        self.chat_template = chat_template
        self.chat_template_content_format: Final = chat_template_content_format
        # Passed to _validate_chat_template for chat requests.
        self.trust_request_chat_template = trust_request_chat_template

    async def create_embedding(
        self,
        request: EmbeddingRequest,
        raw_request: Request | None = None,
    ) -> EmbeddingResponse | ErrorResponse:
        """
        Embedding API similar to OpenAI's API.

        See https://platform.openai.com/docs/api-reference/embeddings/create
        for the API specification. This API mimics the OpenAI Embedding API.
        """
        model_name = self.models.model_name()
        request_id = (
            f"{self.request_id_prefix}-"
            f"{self._base_request_id(raw_request, request.request_id)}"
        )

        ctx = EmbeddingServeContext(
            request=request,
            raw_request=raw_request,
            model_name=model_name,
            request_id=request_id,
            chat_template=self.chat_template,
            chat_template_content_format=self.chat_template_content_format,
        )

        return await super().handle(ctx)  # type: ignore

    @override
    def _create_pooling_params(
        self,
        ctx: ServeContext[EmbeddingRequest],
    ) -> PoolingParams | ErrorResponse:
        """Create pooling params and verify them for the ``embed`` task.

        Returns:
            The verified ``PoolingParams``, or an ``ErrorResponse`` if
            creation or verification fails.
        """
        pooling_params = super()._create_pooling_params(ctx)
        if isinstance(pooling_params, ErrorResponse):
            return pooling_params

        try:
            pooling_params.verify("embed", self.model_config)
        except ValueError as e:
            return self.create_error_response(str(e))

        return pooling_params

    @override
    async def _preprocess(
        self,
        ctx: ServeContext,
    ) -> ErrorResponse | None:
        """Validate any request-supplied chat template, then preprocess.

        For chat requests, ``_validate_chat_template`` runs first and its
        error (if any) is returned. Otherwise this delegates to
        ``EmbeddingMixin._preprocess``.
        """
        if isinstance(ctx.request, EmbeddingChatRequest):
            error_check_ret = self._validate_chat_template(
                request_chat_template=ctx.request.chat_template,
                chat_template_kwargs=ctx.request.chat_template_kwargs,
                trust_request_chat_template=self.trust_request_chat_template,
            )
            if error_check_ret is not None:
                return error_check_ret
        return await super()._preprocess(ctx)
